from sklearn.preprocessing import StandardScaler
import stan
from xbcausalforest import XBCF
import numpy as np
def ComputePosterior(X, A, Y, paramlist,  trt_price=0):
    """Fit the Stan GP model and summarise the posterior treatment effect.

    The covariates are standardized, a Stan program is chosen according to
    the dtype of ``Y``, and two chains of 1000 samples each are drawn.  The
    treatment effect is ``z1 - z0 - trt_price`` computed from the fitted
    quantities ``z1`` and ``z0``.

    Args:
        X: 2-D covariate array; ``X.shape`` supplies ``N`` and ``k``.
        A: Treatment indicator, passed to Stan unchanged as ``A``.
        Y: Outcome.  Any integer dtype selects ``GP_binary.stan``; every
            other dtype selects ``GP.stan``.
        paramlist: Mapping with keys ``'length_scale'`` and
            ``'sigma_kernel'``.  Each value is used for both the baseline
            and the treatment kernel.
        trt_price: Constant subtracted from every treatment-effect draw.

    Returns:
        Tuple ``(EPI, NIP)``: the mean of the treatment effect along axis 1
        and the fraction of entries below zero along axis 1.
        NOTE(review): assumes axis 1 indexes posterior draws (PyStan 3
        ``fit[name]`` layout) -- confirm.

    Note:
        The Stan files are opened relative to the current working
        directory (``Simulation_run/...``).
    """
    # Standardize X before it is handed to the GP.
    scaler = StandardScaler()
    scaler.fit(X)
    X_scale = scaler.transform(X)

    # `Y.dtype == int` only matched the platform default integer, so e.g.
    # int32 outcomes were silently fitted with the continuous model.  Accept
    # every integer dtype instead.
    if np.issubdtype(np.asarray(Y).dtype, np.integer):
        stan_path = "Simulation_run/GP_binary.stan"
    else:
        stan_path = "Simulation_run/GP.stan"
    with open(stan_path, encoding="utf-8") as f:
        GP_stan = f.read()

    GP_data = {
        "N": X.shape[0], 'k': X.shape[1], "y": Y, 'x': X_scale, 'A': A,
        # Baseline and treatment kernels share the same hyper-parameters.
        'length_scale_baseline': paramlist['length_scale'],
        "length_scale_treatment": paramlist['length_scale'],
        'sigma_baseline': paramlist['sigma_kernel'],
        'sigma_treatment': paramlist['sigma_kernel'],
    }
    posterior = stan.build(GP_stan, data=GP_data, random_seed=None)
    fit = posterior.sample(num_chains=2, num_samples=1000)

    trt_effect = fit['z1'] - fit['z0'] - trt_price
    EPI = trt_effect.mean(axis=1)
    NIP = (trt_effect < 0).mean(axis=1)
    return EPI, NIP

def ComputePosteriorXBCF(X, A, Y, paramlist=None,  trt_price=0):
    """Fit an XBCF model and summarise the posterior treatment effect.

    Args:
        X: 2-D covariate array.
        A: Treatment indicator; cast to ``int32`` before fitting.
        Y: Outcome; its variance scales the ``tau_pr`` / ``tau_trt`` settings.
        paramlist: Unused; accepted so the signature parallels
            ``ComputePosterior``.
        trt_price: Constant subtracted from every treatment-effect draw.

    Returns:
        Tuple ``(EPI, NIP)``: the mean of the treatment effect along axis 1
        and the fraction of entries below zero along axis 1, taken over the
        prediction columns from index ``burnin`` onward.

    NOTE(review): the model is fitted and predicted on the raw ``X``; the
    standardized copy only feeds the second ``fit`` argument -- confirm this
    is intended.
    """
    n_trees_pr = 200
    n_trees_trt = 100

    standardized = StandardScaler().fit(X).transform(X)
    outcome_var = np.var(Y)

    # mtry_pr / mtry_trt are not passed, so the library defaults apply.
    settings = {
        "parallel": True,
        "num_sweeps": 500,
        "burnin": 100,
        "max_depth": 250,
        "num_trees_pr": n_trees_pr,
        "num_trees_trt": n_trees_trt,
        "num_cutpoints": 100,
        "Nmin": 1,
        "tau_pr": 0.2 * outcome_var / n_trees_pr,
        "tau_trt": 0.2 * outcome_var / n_trees_trt,
        "alpha_pr": 0.95,  # shrinkage (splitting probability)
        "beta_pr": 2,  # shrinkage (tree depth)
        "alpha_trt": 0.25,  # shrinkage for treatment part
        "beta_trt": 3,
        "p_categorical_pr": 0,
        "p_categorical_trt": 0,
        "standardize_target": True,  # standardize y and unstandardize for prediction
    }
    forest = XBCF(**settings)

    # Append a constant 0.5 column to the standardized covariates.
    half_column = np.ones((standardized.shape[0], 1)) / 2
    design_with_half = np.concatenate([standardized, half_column], axis=1)

    forest.fit(X, design_with_half, Y, A.astype('int32'))

    # Keep only the prediction columns from index `burnin` onward.
    all_draws = forest.predict(X, return_mean=False)
    first_kept = forest.getParams()['burnin']
    net_effect = all_draws[:, first_kept:] - trt_price

    EPI = net_effect.mean(axis=1)
    NIP = (net_effect < 0).mean(axis=1)
    return EPI, NIP